{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "60329565",
   "metadata": {},
   "source": [
    "## Implementing VGG from scratch\n",
    "\n",
    "VGG architecture was one of the key milestone in the development of CNN classifiers. The main contribution of VGG was the thorough investigation and evaluation of increasing depth using an architecture with very small (3×3) convolution filters. VGG was able to push the depth to 19 layers and also improved the State of the Art on ImageNet Challenge. \n",
    "The details of the paper can be found at https://arxiv.org/pdf/1409.1556.pdf\n",
    "\n",
    "Compared to the prior works, VGG architectures had 2 major changes\n",
    "\n",
    "Prior architectures often had relativel larger kernels (7x7, 11x11) in the first conv layers. Instead VGG used very small 3×3 filters throughout the whole net. \n",
    "\n",
    "Note that three 3x3 filters have the same receptive field as a single 7x7 filter. So what does replacing the 7x7 filter with three smaller filters buy? \n",
    "1. With three smaller filters, there is more non-linearity due to ReLU applied after every filter\n",
    "2. The number of parameters are reduced from $49C^2$ ($7^2 C^2$) to $27C^2$ ($3*(3^2C^2)$). A reduced number of parameters means faster learning and more robust to over-fitting.\n",
    "\n",
    "Additionally prior architectures relied on a normalization layer, Local Response Normalization (LRN). The authors showed that LRN layers did not lead to improvement in performance for Imagenet. So these layers were dropped.\n",
    "\n",
    "\n",
    "In this notebook, we will take a look at how to implement a VGG-11 network from scratch. In practice, this is seldom done. `torchvision.models` already provides ready-made implementations for all the VGG architectures. However by building the network from scratch, we will gain a deeper understanding of the architecture."
   ]
  },
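  {
   "cell_type": "markdown",
   "id": "a3f19c02",
   "metadata": {},
   "source": [
    "Returning to the 3x3 versus 7x7 comparison above, both claims can be checked with a short worked calculation. This is only a sketch: it assumes stride-1 convolutions, $C$ channels in and out of every layer, and ignores bias terms. Each additional $k \\times k$ layer grows the receptive field $r$ by $k - 1$, so\n",
    "\n",
    "$$r_0 = 1, \\qquad r_n = r_{n-1} + (k - 1) \\;\\Rightarrow\\; r_3 = 1 + 3 \\cdot (3 - 1) = 7,$$\n",
    "\n",
    "and the ratio of weight counts is\n",
    "\n",
    "$$\\frac{3 \\cdot (3^2 C^2)}{7^2 C^2} = \\frac{27}{49} \\approx 0.55,$$\n",
    "\n",
    "i.e. roughly 45% fewer weights for the same receptive field."
   ]
  },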
  {
   "cell_type": "markdown",
   "id": "d5732840",
   "metadata": {},
   "source": [
    "Regardless of the specific architecture, all VGG networks follow a common structure. The commonalities are listed\n",
    "1. All architectures work on (224, 224) sized input images.\n",
    "2. All architectures have 5 Conv Blocks. \n",
    "    1. Each Block can have multiple number of convolutional layers followed by a MaxPool layer.\n",
    "    2. All the individual convolutional layers inside a block are 3x3 kernels, with a padding of 1\n",
    "    3. The individual convolutional layers do not change the spatial resolution of the feature map\n",
    "    4. All the individual convolutional layers within a block  have the same output size features\n",
    "    5. Each Convolutional block has a MaxPool layer at the end which reduces the spatial resolution of the feature map\n",
    "3. Since each block downsamples the size by 2, at the end, the input image of size (224, 224) is reduced by a factor of $2^5$ i.e (7,7). Additionally at each block the number of features are doubled. \n",
    "4. All architectures have a classifier which comprises of 3 Fully Connected (FC) layers\n",
    "    1. The first FC takes a 512*7*7 input, converts it to a 4096 dimensional output\n",
    "    2. The second FC takes the resulting 4096 dimensional output and converts it to another 4096 dimensional output\n",
    "    3. The final FC layer converts the 4096 dimensional input to a 1000 dimensional output. 1000 being the number of classes for ImageNet."
   ]
  },
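  {
   "cell_type": "markdown",
   "id": "b47d2e15",
   "metadata": {},
   "source": [
    "The sizes quoted above can be checked with a little arithmetic. As a worked sketch, assuming exactly five stride-2 MaxPool layers and 512 channels coming out of the last block:\n",
    "\n",
    "$$\\frac{224}{2^5} = \\frac{224}{32} = 7, \\qquad 512 \\times 7 \\times 7 = 25088.$$\n",
    "\n",
    "So the first FC layer sees a flattened vector of 25088 values per image."
   ]
  },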
  {
   "cell_type": "markdown",
   "id": "5b733a69",
   "metadata": {},
   "source": [
    "Let us first impelement the convolutional block. As mentioned earlier, each block can have multiple convolutional layers (depending on the architectures), followed by a MaxPool layer.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d3eabeda",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "from torch import nn\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "51482180",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvBlock(nn.Module):\n",
    "    def __init__(self, in_channels, num_conv_layers, num_features):\n",
    "        super(ConvBlock, self).__init__()\n",
    "        modules = []\n",
    "        for i in range(num_conv_layers):\n",
    "            modules.extend([\n",
    "                nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1),\n",
    "                nn.ReLU(inplace=True)\n",
    "            ])\n",
    "            in_channels = num_features\n",
    "        modules.append(nn.MaxPool2d(kernel_size=2))\n",
    "        self.conv_block = nn.Sequential(*modules)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv_block(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f07dfe30",
   "metadata": {},
   "outputs": [],
   "source": [
    "conv_1 = ConvBlock(in_channels=3, num_conv_layers=1, num_features=64)\n",
    "\n",
    "# Let us simulate a forward pass using a dummy input tensor\n",
    "x = torch.rand([1, 3, 224, 224])\n",
    "\n",
    "conv_1_out = conv_1(x)\n",
    "\n",
    "assert conv_1_out.shape == torch.Size([1, 64, 112, 112])"
   ]
  },
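  {
   "cell_type": "markdown",
   "id": "c8e05a3b",
   "metadata": {},
   "source": [
    "The shape asserted above follows from the usual output-size formula for convolution and pooling layers. As a sketch, assuming no dilation and a square input of side $H$, with kernel size $k$, padding $p$ and stride $s$:\n",
    "\n",
    "$$H_{\\text{out}} = \\left\\lfloor \\frac{H + 2p - k}{s} \\right\\rfloor + 1$$\n",
    "\n",
    "For the 3x3 convolution ($k = 3$, $p = 1$, $s = 1$) this gives $(224 + 2 - 3) + 1 = 224$, so the resolution is unchanged. For the 2x2 MaxPool ($k = 2$, $p = 0$, $s = 2$, since `nn.MaxPool2d` uses the kernel size as its default stride) it gives $\\lfloor (224 - 2) / 2 \\rfloor + 1 = 112$."
   ]
  },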
  {
   "cell_type": "markdown",
   "id": "3c4ed49b",
   "metadata": {},
   "source": [
    "Let us now implement the conv backbone builder. This will take in a list of configurations for each of the individual convolutional block.\n",
    "\n",
    "The config is a list of size 5 (corresponding to each Conv block). Each element is a 3 tuple of the form (in_channels, num_conv_layers, num_features) corresponding to that Conv block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f8dc4636",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvBackbone(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super(ConvBackbone, self).__init__()\n",
    "        \n",
    "        self.cfg = cfg\n",
    "        self.validate_config(cfg)\n",
    "        \n",
    "        modules = []\n",
    "        for block_cfg in cfg:\n",
    "            in_channels, num_conv_layers, num_features = block_cfg\n",
    "            modules.append(ConvBlock(in_channels, num_conv_layers, num_features))\n",
    "        self.features = nn.Sequential(*modules)\n",
    "            \n",
    "    def validate_config(self, cfg):\n",
    "        assert len(cfg) == 5 # 5 Conv blocks\n",
    "        for i, block_cfg in enumerate(cfg):\n",
    "            assert type(block_cfg) == tuple and len(block_cfg) == 3\n",
    "            if i == 0:\n",
    "                assert block_cfg[0] == 3 #Input channels always has to be 3\n",
    "            else:\n",
    "                assert block_cfg[0] == cfg[i-1][-1] #num_features of previous block is the input num features to this block\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.features(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "521a4b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let us look at the VGG 11 backbone configuration\n",
    "vgg11_cfg = [\n",
    "    (3, 1, 64),\n",
    "    (64, 1, 128),\n",
    "    (128, 2, 256),\n",
    "    (256, 2, 512),\n",
    "    (512, 2, 512)\n",
    "]\n",
    "\n",
    "vgg11_backbone = ConvBackbone(vgg11_cfg)\n",
    "# Let us simulate a forward pass using a dummy input tensor\n",
    "x = torch.rand([1, 3, 224, 224])\n",
    "\n",
    "vgg11_conv_out = vgg11_backbone(x)\n",
    "\n",
    "assert vgg11_conv_out.shape == torch.Size([1, 512, 7, 7])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63a16f",
   "metadata": {},
   "source": [
    "Now let us implement the VGG module. As mentioned previously, the VGG module has 2 key features\n",
    "1. The Conv backbone comprising of 5 Convolutional Blocks\n",
    "2. The classifier comprising of Fully Connected Layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cd43efee",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VGG(nn.Module):\n",
    "    def __init__(self, conv_backbone, num_classes):\n",
    "        super(VGG, self).__init__()\n",
    "        self.conv_backbone = conv_backbone\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(512 * 7 * 7, 4096),\n",
    "            nn.ReLU(True),\n",
    "            nn.Dropout(),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(True),\n",
    "            nn.Dropout(),\n",
    "            nn.Linear(4096, num_classes)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        conv_features = self.conv_backbone(x)\n",
    "        # We need to flatten the conv features before passing it to the classifier\n",
    "        logits = self.classifier(conv_features.view(conv_features.shape[0], -1)) \n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "75817590",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We are now ready to create our VGG11. Let's say we want to use a VGG-11 classifier for a 1000 class problem\n",
    "num_classes = 1000\n",
    "vgg11 = VGG(vgg11_backbone, num_classes)\n",
    "\n",
    "x = torch.rand([1, 3, 224, 224])\n",
    "logits = vgg11(x)\n",
    "assert logits.shape == torch.Size([1, num_classes])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9753d7d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGG(\n",
       "  (conv_backbone): ConvBackbone(\n",
       "    (features): Sequential(\n",
       "      (0): ConvBlock(\n",
       "        (conv_block): Sequential(\n",
       "          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (1): ReLU(inplace=True)\n",
       "          (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        )\n",
       "      )\n",
       "      (1): ConvBlock(\n",
       "        (conv_block): Sequential(\n",
       "          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (1): ReLU(inplace=True)\n",
       "          (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        )\n",
       "      )\n",
       "      (2): ConvBlock(\n",
       "        (conv_block): Sequential(\n",
       "          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (1): ReLU(inplace=True)\n",
       "          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (3): ReLU(inplace=True)\n",
       "          (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        )\n",
       "      )\n",
       "      (3): ConvBlock(\n",
       "        (conv_block): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (1): ReLU(inplace=True)\n",
       "          (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (3): ReLU(inplace=True)\n",
       "          (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        )\n",
       "      )\n",
       "      (4): ConvBlock(\n",
       "        (conv_block): Sequential(\n",
       "          (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (1): ReLU(inplace=True)\n",
       "          (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "          (3): ReLU(inplace=True)\n",
       "          (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
       "    (1): ReLU(inplace=True)\n",
       "    (2): Dropout(p=0.5, inplace=False)\n",
       "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "    (4): ReLU(inplace=True)\n",
       "    (5): Dropout(p=0.5, inplace=False)\n",
       "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vgg11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "95e1d72e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 224, 224]           1,792\n",
      "              ReLU-2         [-1, 64, 224, 224]               0\n",
      "         MaxPool2d-3         [-1, 64, 112, 112]               0\n",
      "         ConvBlock-4         [-1, 64, 112, 112]               0\n",
      "            Conv2d-5        [-1, 128, 112, 112]          73,856\n",
      "              ReLU-6        [-1, 128, 112, 112]               0\n",
      "         MaxPool2d-7          [-1, 128, 56, 56]               0\n",
      "         ConvBlock-8          [-1, 128, 56, 56]               0\n",
      "            Conv2d-9          [-1, 256, 56, 56]         295,168\n",
      "             ReLU-10          [-1, 256, 56, 56]               0\n",
      "           Conv2d-11          [-1, 256, 56, 56]         590,080\n",
      "             ReLU-12          [-1, 256, 56, 56]               0\n",
      "        MaxPool2d-13          [-1, 256, 28, 28]               0\n",
      "        ConvBlock-14          [-1, 256, 28, 28]               0\n",
      "           Conv2d-15          [-1, 512, 28, 28]       1,180,160\n",
      "             ReLU-16          [-1, 512, 28, 28]               0\n",
      "           Conv2d-17          [-1, 512, 28, 28]       2,359,808\n",
      "             ReLU-18          [-1, 512, 28, 28]               0\n",
      "        MaxPool2d-19          [-1, 512, 14, 14]               0\n",
      "        ConvBlock-20          [-1, 512, 14, 14]               0\n",
      "           Conv2d-21          [-1, 512, 14, 14]       2,359,808\n",
      "             ReLU-22          [-1, 512, 14, 14]               0\n",
      "           Conv2d-23          [-1, 512, 14, 14]       2,359,808\n",
      "             ReLU-24          [-1, 512, 14, 14]               0\n",
      "        MaxPool2d-25            [-1, 512, 7, 7]               0\n",
      "        ConvBlock-26            [-1, 512, 7, 7]               0\n",
      "     ConvBackbone-27            [-1, 512, 7, 7]               0\n",
      "           Linear-28                 [-1, 4096]     102,764,544\n",
      "             ReLU-29                 [-1, 4096]               0\n",
      "          Dropout-30                 [-1, 4096]               0\n",
      "           Linear-31                 [-1, 4096]      16,781,312\n",
      "             ReLU-32                 [-1, 4096]               0\n",
      "          Dropout-33                 [-1, 4096]               0\n",
      "           Linear-34                 [-1, 1000]       4,097,000\n",
      "================================================================\n",
      "Total params: 132,863,336\n",
      "Trainable params: 132,863,336\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.57\n",
      "Forward/backward pass size (MB): 137.05\n",
      "Params size (MB): 506.83\n",
      "Estimated Total Size (MB): 644.46\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# We can now take a look at the summary to visualize the output shape, number of parameters and the layers\n",
    "summary(vgg11, input_size=(3, 224, 224))"
   ]
  },
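  {
   "cell_type": "markdown",
   "id": "d92b6f47",
   "metadata": {},
   "source": [
    "The numbers in the `Param #` column can be reproduced by hand. As a worked sketch, using the standard counts for layers with a bias term: a $k \\times k$ convolution from $C_{\\text{in}}$ to $C_{\\text{out}}$ channels has $k^2 C_{\\text{in}} C_{\\text{out}} + C_{\\text{out}}$ parameters, and a fully connected layer from $N_{\\text{in}}$ to $N_{\\text{out}}$ units has $N_{\\text{in}} N_{\\text{out}} + N_{\\text{out}}$. For the first conv layer and the first FC layer:\n",
    "\n",
    "$$3^2 \\cdot 3 \\cdot 64 + 64 = 1{,}792, \\qquad 25088 \\cdot 4096 + 4096 = 102{,}764{,}544.$$\n",
    "\n",
    "Summing the table, the three FC layers hold 123,642,856 of the 132,863,336 parameters, i.e. about 93% of the model."
   ]
  },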
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "851c8433",
   "metadata": {},
   "outputs": [],
   "source": [
    "# As an elementary check, let us compare the number of parameters between our implementation\n",
    "# and the official torchivision implementation and assert that they are equal\n",
    "\n",
    "num_vgg_params = sum(p.numel() for p in vgg11.parameters() if p.requires_grad)\n",
    "\n",
    "torch_vgg11 = torchvision.models.vgg11()\n",
    "num_torch_vgg_params = sum(p.numel() for p in torch_vgg11.parameters() if p.requires_grad)\n",
    "\n",
    "assert num_vgg_params == num_vgg_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "488bb73b",
   "metadata": {},
   "source": [
    "And voila! We have our own bare-bones implementation of VGG-11. Note that this is a barebones implementation. (We could potentially add BatchNorm in between the Conv layers. Similarly, we can apply AveragePool to the output of ConvBackbone to handle variable sized images). The purpose here is to understand and get a sense for the broad architecture."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
